from State.buffer import VectorGCReplayBufferManager, GCReplayBuffer
import numpy as np
from typing import Any, List, Tuple, Union, Optional, Callable
from collections import deque
import copy, time

import torch
import gymnasium as gym
import numpy as np
from numba import njit
from tianshou.data import (Batch, SegmentTree, to_numpy,
                           ReplayBuffer, PrioritizedReplayBuffer, HERReplayBuffer,
                           ReplayBufferManager, PrioritizedReplayBufferManager)

import scipy
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from State.utils import Running
from Policy.hindsight_filter import IGNORE_FIRST

def her_summary_statistics(her_add_statistics):
    """Convert the raw HER counters collected by ``add_her_trajectory`` into rates/means.

    Expects the counters ``her/false_positive``, ``her/false_negative``,
    ``her/match_true``, ``her/true_reached``, ``her/num_ended``, ``her/reached`` and
    the lists ``her/graph_totals``, ``her/true_graph_totals``, ``her/total_positive``,
    ``her/total_negative``. The dict is modified in place and also returned.

    If the dict carries no HER counters (e.g. an empty dict), the summary keys are
    set to 0 so the same keys are always logged.
    """
    def _mean_or_zero(values):
        # np.mean([]) warns and yields nan; report 0 when no trajectory ended instead
        return np.mean(values) if len(values) > 0 else 0

    stats = her_add_statistics
    if "her/false_positive" in stats:
        num_ended = stats["her/num_ended"]
        true_reached = stats["her/true_reached"]
        stats["her/false_positive"] = stats["her/false_positive"] / max(1, stats["her/false_positive"] + true_reached)
        stats["her/false_negative"] = stats["her/false_negative"] / max(1, stats["her/false_negative"] + num_ended - true_reached)
        stats["her/match_true"] = stats["her/match_true"] / max(1, num_ended)
        stats["her/positive_rate"] = true_reached / max(1, num_ended)
        stats["her/true_graph_totals"] = _mean_or_zero(stats["her/true_graph_totals"])
        stats["her/graph_totals"] = _mean_or_zero(stats["her/graph_totals"])
        # both totals are normalised by the number of trajectories the check marked reached
        num_reached = max(1, np.sum(stats["her/reached"]))
        stats["her/total_positive"] = np.sum(stats["her/total_positive"]) / num_reached
        stats["her/total_negative"] = np.sum(stats["her/total_negative"]) / num_reached
    else:
        stats["her/false_positive"] = 0
        stats["her/false_negative"] = 0
        stats["her/positive_rate"] = 0
        stats["her/match_true"] = 0
        stats["her/true_reached"] = 0
        stats["her/true_graph_totals"] = 0
        stats["her/graph_totals"] = 0
        # keep the key set identical to the populated branch
        stats["her/total_positive"] = 0
        stats["her/total_negative"] = 0
    return stats

def get_last_achieved(reached_idx, achieved_goal, use_lowest_post = False):
    """Pick the achieved goal used as the hindsight target, returned as a deep copy.

    By default this is the final entry ``achieved_goal[-1]``. With
    ``use_lowest_post`` it is the entry at or after ``reached_idx`` whose first
    coordinate (column 0) is smallest.
    """
    if not use_lowest_post:
        return copy.deepcopy(achieved_goal[-1])
    # argmin is taken over the suffix, so shift it back to an absolute position
    offset = np.argmin(achieved_goal[reached_idx:, 0])
    return copy.deepcopy(achieved_goal[offset + reached_idx])


def add_her_trajectory(buffer, temp_buffers, temp_buffer_ptrs, done_indices, her_trajectory_check, check_rew, check_term, num_samples= 1, her_traj_length=-1, use_lowest_post=False, episode_counter=-1):
    """Relabel finished trajectories with hindsight goals and add them through ``buffer.add_her``.

    For every index ``di`` in ``done_indices``, the pending transitions
    ``temp_buffers[di][:temp_buffer_ptrs[di]]`` are checked with
    ``her_trajectory_check``. When the check reports success:

    * the trajectory is re-added with the goal replaced by the "last achieved"
      goal (see ``get_last_achieved``). When ``her_traj_length > 0``, only its
      last ``her_traj_length + 1`` transitions are re-added.
    * up to ``num_samples - 1`` extra copies are added. Each is relabeled with the
      achieved goal of a random transition at or after the first reached index.
      When ``her_traj_length > 0``, each copy is restricted to a window of about
      ``her_traj_length`` transitions around that index.

    Rewards and termination flags of relabeled transitions are recomputed with
    ``check_rew`` / ``check_term``. Afterwards the processed pointers in
    ``temp_buffer_ptrs`` are reset to 0 in place.

    :param buffer: object exposing ``add_her(batch)``.
    :param temp_buffers: per-environment pending trajectories, sliceable by position.
    :param temp_buffer_ptrs: fill counts of ``temp_buffers``. It is fancy-indexed
        with ``done_indices``, so presumably a numpy array.
    :param done_indices: indices of the environments whose episode ended.
    :param her_trajectory_check: callable ``traj -> (success, reached_idxes,
        graph_totals, true_success, true_idxes, true_totals)``.
    :param check_rew: recomputes the reward of a relabeled transition.
    :param check_term: recomputes the termination flag of a relabeled transition.
    :param num_samples: total number of relabeled copies per successful trajectory.
    :param her_traj_length: if > 0, limits how many transitions each copy contains.
    :param use_lowest_post: forwarded to ``get_last_achieved``.
    :param episode_counter: unused.
    :return: dict of raw statistics, intended for ``her_summary_statistics``.
    """
    ended = [temp_buffers[di][:temp_buffer_ptrs[di]] for di in done_indices]
    num_ended = len(ended)

    # keep some statistics about the trajectories
    her_add_statistics = {"her/num_ended": num_ended,
                          "her/reached": 0,
                          "her/num_traj": 0,
                          "her/graph_totals": list(),
                          "her/true_graph_totals": list(),
                          "her/match_true": 0,
                          "her/false_positive": 0,
                          "her/false_negative": 0,
                          "her/true_reached": 0,
                          "her/total_positive": list(),
                          "her/total_negative": list(),
                          "her/achieved_desired": list(),
                          }
    for traj in ended:
        success, reached_idxes, graph_totals, true_success, true_idxes, true_totals = her_trajectory_check(traj)

        # compare the hindsight check against the ground-truth check
        her_add_statistics["her/match_true"] += int(success == true_success)
        her_add_statistics["her/false_positive"] += int(success != true_success and success)
        her_add_statistics["her/false_negative"] += int(success != true_success and not success)

        her_add_statistics["her/reached"] += int(success)
        her_add_statistics["her/true_reached"] += int(true_success)
        her_add_statistics["her/graph_totals"].append(graph_totals)
        # record the ground-truth totals (this previously appended graph_totals,
        # which made the "true" statistic a duplicate of the predicted one)
        her_add_statistics["her/true_graph_totals"].append(true_totals)
        her_add_statistics["her/achieved_desired"].append([[success], traj.obs.achieved_goal[IGNORE_FIRST] - traj.obs.desired_goal[0]])
        if true_success: her_add_statistics["her/total_positive"].append(graph_totals)
        else: her_add_statistics["her/total_negative"].append(graph_totals)
        num_traj = 0
        last_achieved = get_last_achieved(0, traj.obs.achieved_goal, use_lowest_post=use_lowest_post)
        if success:
            # always add in the last state as a hindsight goal
            reached_idx = reached_idxes[0]
            use_start = max(0, len(traj) - her_traj_length - 1) if her_traj_length > 0 else 0
            last_achieved = get_last_achieved(reached_idx, traj.obs.achieved_goal, use_lowest_post=use_lowest_post)
            for i in range(use_start, len(traj)):
                new_val = copy.deepcopy(traj[i])
                new_val.obs = Batch(desired_goal= last_achieved, achieved_goal = traj[i].obs.achieved_goal, observation = traj[i].obs.observation, reached_graph_counter=traj[i].obs.reached_graph_counter)
                new_val.obs_next = Batch(desired_goal= last_achieved, achieved_goal = traj[i].obs_next.achieved_goal, observation = traj[i].obs_next.observation, reached_graph_counter=traj[i].obs_next.reached_graph_counter)
                new_val.rew = check_rew(new_val)
                new_val.terminate = check_term(new_val)
                buffer.add_her(new_val)
            num_traj += 1

            # add in as many additional trajectories as num_samples - 1
            for _ in range(min(len(traj) - reached_idx, num_samples - 1)):
                samp_idx = np.random.randint(reached_idx, len(traj))
                if her_traj_length > 0:
                    # put the sampled index in the middle of the window, clipped to the
                    # trajectory: a negative start would wrap around to the end of traj
                    # and an end past len(traj) would index out of range
                    use_start = max(0, int(samp_idx - her_traj_length / 2))
                    use_end = min(len(traj), int(samp_idx + her_traj_length / 2))
                else:
                    use_start, use_end = 0, len(traj)
                sampled_goal = traj[samp_idx].obs.achieved_goal
                for i in range(use_start, use_end):
                    new_val = copy.deepcopy(traj[i])
                    # relabel both obs and obs_next, as the main relabeling loop does
                    new_val.obs.desired_goal = sampled_goal
                    new_val.obs_next.desired_goal = sampled_goal
                    new_val.rew = check_rew(new_val)
                    new_val.terminate = check_term(new_val)
                    buffer.add_her(new_val)
                num_traj += 1

        her_add_statistics["her/achieved_desired"][-1] = [np.concatenate(her_add_statistics["her/achieved_desired"][-1] + [traj.obs.achieved_goal[IGNORE_FIRST] - last_achieved]).tolist()]
        her_add_statistics["her/num_traj"] += num_traj

    temp_buffer_ptrs[done_indices] = 0
    return her_add_statistics

class VectorGCHindsightReplayBufferManager(VectorGCReplayBufferManager):
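    """Vector goal-conditioned replay buffer manager with an optional hindsight (HER) buffer.

    It holds ``buffer_num`` GCReplayBuffers of equal size and stores transitions from
    different environments while keeping their temporal order. When ``use_her`` is
    True, one extra buffer of size ``total_size // buffer_num`` is appended to hold
    hindsight-relabeled transitions; its indices start at ``total_size``.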

    :param int total_size: the total size of VectorReplayBuffer.
    :param int buffer_num: the number of ReplayBuffer it uses, which are under the same
        configuration.

    Other input arguments (stack_num/ignore_obs_next/save_only_last_obs/sample_avail)
    are the same as :class:`~tianshou.data.ReplayBuffer`.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """
    def __init__(
            self,
            env,
            total_size: int,
            buffer_num: int,
            use_her: bool = False,
            her_use_count_select_goal: bool = False,
            horizon: int = 1,
            future_k: float = 8.0,
            reached_graph_threshold: int = 1,
            alpha: float = 0.6,
            policy_per_alpha: float = 0.6,
            dynamics_per_alpha: float = 0.6,
            beta: float = 0.4,
            weight_norm: bool = True,
            dynamics_per_priority_scale: float = 1.0,
            dynamics_per_update_count_scale: float = 0.0,
            dynamics_per_change_count_scale: float = 0.0,
            policy_per_td_error_scale: float = 1.0,
            policy_per_graph_count_scale: float = 1.0,
            policy_per_graph_count_power: float = -0.5,
            count_threshold_for_valid_graph: int = 0,
            decay_window: int = 5,
            decay_rate: float = 0.4,
            max_prev_decay: float = 0.7,
            her_ratio: float = 0.5,
            use_prio: bool = False,
            target_idx: int = 2,
            **kwargs: Any,
    ) -> None:
        """Build the manager, reserving one extra child buffer for HER data when ``use_her``.

        The parent is always created with ``use_her=False``. When ``use_her`` is
        True, it is created with one additional buffer of size
        ``total_size // buffer_num``. That last buffer becomes ``self.her_buffer``
        and is filled through ``add_her``. Global indices ``>= self.single_total_size``
        refer to the HER buffer.

        :param her_ratio: fraction of each sampled batch drawn from the HER buffer.
        :param target_idx: factor index used for policy-PER graph-count weighting;
            non-positive values are offset by ``num_factors + 2``.
        All other parameters are forwarded unchanged to ``VectorGCReplayBufferManager``.
        """
        super().__init__(env = env,
            total_size = total_size + int(total_size // buffer_num) * int(use_her),
            buffer_num = buffer_num + int(use_her),
            use_her = False,
            her_use_count_select_goal = her_use_count_select_goal,
            horizon = horizon,
            future_k = future_k,
            reached_graph_threshold = reached_graph_threshold,
            alpha = alpha,
            policy_per_alpha = policy_per_alpha,
            dynamics_per_alpha = dynamics_per_alpha,
            beta = beta,
            weight_norm = weight_norm,
            dynamics_per_priority_scale = dynamics_per_priority_scale,
            dynamics_per_update_count_scale = dynamics_per_update_count_scale,
            dynamics_per_change_count_scale = dynamics_per_change_count_scale,
            policy_per_td_error_scale = policy_per_td_error_scale,
            policy_per_graph_count_scale = policy_per_graph_count_scale,
            policy_per_graph_count_power = policy_per_graph_count_power,
            count_threshold_for_valid_graph = count_threshold_for_valid_graph,
            decay_window = decay_window,
            decay_rate = decay_rate,
            max_prev_decay = max_prev_decay,
            use_prio=use_prio,
            target_idx = target_idx,
            **kwargs)
        self.her_ratio = her_ratio
        # size of the regular (non-HER) storage; HER indices are offset by this value
        self.single_total_size = total_size
        self.her_size = int(total_size // buffer_num) * int(use_her)
        # NOTE(review): with "> 0", target_idx == 0 maps to num_factors + 2 rather than 0;
        # confirm whether ">= 0" was intended.
        self.factor_idx = target_idx if target_idx > 0 else target_idx + self.num_factors + 2
        if use_her:
            self.use_her = True
            self.her_buffer = self.buffers[-1]
            self.her_idx = buffer_num
            # NOTE(review): this calls init_HER_utils (upper case) while this class defines
            # init_her_utils; presumably the former is provided by the parent class.
            self.init_HER_utils()
            # Writing ``self.her_buffer.__eps`` here would be name-mangled to
            # ``_VectorGCHindsightReplayBufferManager__eps`` and never reach the PER
            # epsilon. Set the mangled name PrioritizedReplayBuffer itself uses (the same
            # one read in update_weight_indices).
            self.her_buffer._PrioritizedReplayBuffer__eps = 1e-30
            self.reset_policy_PER()


    def init_logging_stats(self) -> None:
        """Reset ``self.her_stats`` to one dict of empty HER-selection / sampling lists per factor."""
        # TODO: add any new logging stats here
        # for logging her resampling stats
        self.her_stats = []
        for _factor in range(self.num_factors):
            self.her_stats.append({"her_selected_count_idx": [],
                                   "sampled_count_idx": []})

    def sample(
            self,
            batch_size: int,
            policy_prio: bool = False,
            dynamics_prio: bool = False,
            no_her: bool = False,
            factor_idx: int = 1,
            her_update_achieved_goal: Callable = None,
    ) -> Tuple[Batch, np.ndarray]:
        """Sample a batch, mixing regular and HER transitions.

        Replaces tianshou's ``sample`` to support optional prioritization and a HER share.
        When HER is enabled and the HER buffer holds at least 1000 transitions,
        ``round(batch_size * her_ratio)`` samples come from the HER buffer (always
        uniform) and the rest from the regular buffers. Otherwise everything comes
        from the regular buffers. ``batch_size == 0`` returns all data.

        :param factor_idx: unused.
        :param her_update_achieved_goal: unused.
        :return: ``(batch, indices)`` with regular rows first, then HER rows. HER
            indices are offset by ``self.single_total_size``.
        """
        assert not (policy_prio and dynamics_prio)
        if self.use_prio:
            assert not self.prio_cache.used_per, "must call update_weight() after each sample() call to clear the cache"
        no_her = no_her or (not self.use_her)

        use_per = policy_prio or dynamics_prio
        self.prio_cache = Batch(used_per=use_per, tree_name=None)
        if use_per:
            self.weight = self.tree
        # priority sampling applies only to the regular buffers, and only when enabled
        main_prio = self.use_prio and use_per

        if no_her or (not hasattr(self, "her_buffer") or len(self.her_buffer) < 1000): # wait until HER buffer has at least 1000 states
            indices = self.sample_indices(batch_size, use_prio=main_prio)
            return self[indices], indices

        if batch_size == 0:
            indices = super().sample_indices(0)
            return self[indices], indices

        # Derive one part from the other so the two always add up to batch_size.
        # Rounding both independently (np.round rounds half to even) dropped or added
        # a sample on odd sizes. For batch_size == 1 it requested 0 from each side,
        # and size 0 means "return every stored index".
        her_batch_size = int(np.round(batch_size * self.her_ratio))
        main_batch_size = batch_size - her_batch_size

        if main_batch_size == 0:
            # TODO: no priority on HER since her doesn't have a separate segtree
            her_indices = self.sample_her_indices(her_batch_size)
            return self.her_buffer[her_indices], self.single_total_size + her_indices

        indices = self.sample_indices(main_batch_size, use_prio=main_prio)
        if her_batch_size == 0:
            return self[indices], indices

        her_indices = self.sample_her_indices(her_batch_size)
        return Batch.cat([self[indices], self.her_buffer[her_indices]]), np.concatenate([indices, self.single_total_size + her_indices])

    def update_weight_indices(self, buffer, new_weight, index, alpha):
        """Store PER priorities ``(|new_weight| + eps) ** alpha`` for the regular-buffer indices.

        Entries whose index is ``>= self.single_total_size`` (HER transitions) are
        ignored, because HER data is sampled without priorities. ``buffer._max_prio``
        and ``buffer._min_prio`` are widened with the new raw weights, and
        ``buffer.tree`` is pointed at ``buffer.weight``.
        """
        # because of python name mangling, cant access self.__eps; read
        # PrioritizedReplayBuffer's private eps via its mangled name (assumes eps is the same for both)
        weight = np.abs(to_numpy(new_weight)) + buffer.__dict__['_PrioritizedReplayBuffer__eps']
        in_main = index < self.single_total_size
        main_indices = index[in_main]
        main_weight = weight[in_main]
        buffer.weight[main_indices] = main_weight ** alpha
        # a batch made only of HER samples has nothing to update; max()/min() of an
        # empty array would raise ValueError
        if main_weight.size > 0:
            buffer._max_prio = max(buffer._max_prio, main_weight.max())
            buffer._min_prio = min(buffer._min_prio, main_weight.min())
        buffer.tree = buffer.weight


    def new_weight_compute(self,buffer, new_weight, index):
        """Turn TD errors into policy-PER priorities, optionally with PSER decay and graph counts.

        :param buffer: provides ``prev``, ``pser_stats`` and ``compute_graph_count_weight``.
            ``update_weight`` passes ``self``.
        :param new_weight: TD errors for ``index``; only their magnitude is used.
        :param index: 1-D array of global buffer indices.
        :return: ``(new_weight, index)``. With PSER the returned index also includes up
            to ``decay_window`` predecessors of each input index, minus consecutive
            duplicates, so it can differ in length from the input.
        """
        td_error = np.abs(to_numpy(new_weight))
        # Work with magnitudes from here on. Without PSER, the raw (possibly signed)
        # values used to reach the graph-count mix below, where a negative TD error
        # could cancel the positive count term before the final abs().
        new_weight = td_error

        if self.policy_use_pser:
            # update weight for policy training
            n_steps = self.decay_window + 1
            n_step_indices = np.empty(n_steps * len(index), dtype=index.dtype)
            # interleave each index with its predecessors:
            # [i0, prev(i0), prev^2(i0), ..., i1, prev(i1), ...]
            n_step_indices[0::n_steps] = index
            for i in range(1, n_steps):
                index = buffer.prev(index)
                n_step_indices[i::n_steps] = index

            # new td_error_t ← max{|new_weight|, self.max_prev_decay * old_td_error_t}
            # for i in range(1, n_steps):
            #   new td_error_{t - i} ← max{|new_weight| * self.decay_rate ** i, old_td_error_{t - i}}
            td_error_buf = buffer.pser_stats

            # to avoid too fast decay:
            # new td_error_t ← max{|new_weight|, self.max_prev_decay * old_td_error_t}
            # if the data has never been sampled before old_td_error_t = 0
            n_step_old_td_error = td_error_buf[n_step_indices]
            n_step_old_td_error[0::n_steps] = self.max_prev_decay * n_step_old_td_error[0::n_steps]

            n_step_new_td_error = np.empty_like(n_step_old_td_error)
            n_step_new_td_error[0::n_steps] = td_error
            for i in range(1, n_steps):
                n_step_new_td_error[i::n_steps] = td_error * self.decay_rate ** i

            n_step_new_td_error = np.maximum(n_step_new_td_error, n_step_old_td_error)
            td_error = n_step_new_td_error[0::n_steps]

            # filter out repeated (consecutive) indices in n_step_indices
            # TODO: do not consider repeated indices in index
            unique_mask = np.concatenate([[True], n_step_indices[1:] != n_step_indices[:-1]])
            index = n_step_indices[unique_mask]
            new_weight = td_error_buf[index] = n_step_new_td_error[unique_mask]

        if self.policy_per_use_graph_count:
            # compute td_error moving average to scale count_weight
            td_error_stats_i = self.td_error_stats[self.factor_idx]
            td_error_stats_i.add(td_error.mean())

            count_weight = buffer.compute_graph_count_weight(index, self.factor_idx) #TODO: might have issues since her counts != full buffer counts
            new_weight = new_weight * count_weight / td_error_stats_i.mean

            new_weight = self.policy_per_td_error_scale * new_weight + self.policy_per_graph_count_scale * count_weight
        return new_weight, index

    def update_weight(
        self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]
    ) -> None:
        """Apply new TD errors to the policy priorities after a prioritized ``sample()``.

        Does nothing unless the previous ``sample()`` used PER. Otherwise:
        1. the regular-buffer entries (index < ``single_total_size``) go to the parent's
           ``update_weight``;
        2. every entry goes through ``new_weight_compute`` and ``update_weight_indices``;
        3. the sampling cache is cleared.
        """
        assert index.ndim == 1, "index, i.e., batch size must be 1D"
        if not self.prio_cache.used_per:
            return

        in_main = index < self.single_total_size
        super().update_weight(index[in_main], new_weight[in_main])

        # treat prioritized replay trees as separate for her and normal sampling;
        # move the weights to host memory to prevent cuda errors
        weights_np = to_numpy(new_weight)
        alpha = self.policy_per_alpha
        weights_np, touched_index = self.new_weight_compute(self, weights_np, index)
        self.update_weight_indices(self, weights_np, touched_index, alpha)

        self.prio_cache = Batch(used_per=False, tree_name=None)

    def sample_indices(
            self,
            batch_size: int,
            use_prio: bool = False,
            factor_idx: int = 1
    ) -> np.ndarray:
        """Sample global indices from the regular (non-HER) buffers.

        With ``use_prio``, ``PrioritizedReplayBuffer.sample_indices`` draws from the
        manager's priority weights. Otherwise ``ReplayBufferManager.sample_indices``
        samples uniformly. When a HER buffer exists (the last child buffer), it is
        temporarily hidden during that call. ``batch_size == 0`` returns all
        available regular indices.

        :param factor_idx: unused.
        """
        self._restore_cache()
        if use_prio:
            return PrioritizedReplayBuffer.sample_indices(self, batch_size)    # sample with priority
        if not self.use_her:
            # no HER buffer was appended, so every child holds regular data; hiding
            # the last child would silently exclude the last environment
            return ReplayBufferManager.sample_indices(self, batch_size)
        # set the lengths to ignore the HER buffer
        lengths, offsets, buffer_num = self._lengths, self._offset, self.buffer_num
        self._lengths, self._offset, self.buffer_num = lengths[:-1], offsets[:-1], buffer_num - 1
        try:
            return ReplayBufferManager.sample_indices(self, batch_size)
        finally:
            # restore even if sampling raises, so the manager is never left truncated
            self._lengths, self._offset, self.buffer_num = lengths, offsets, buffer_num

    def sample_her_indices(
            self,
            batch_size: int,
            use_prio: bool = False,
            factor_idx: int = 1
    ) -> np.ndarray:
        """Sample indices local to ``self.her_buffer``.

        ``use_prio`` selects ``PrioritizedReplayBuffer.sample_indices``; otherwise plain
        ``ReplayBuffer.sample_indices`` is used. ``factor_idx`` is unused. Callers
        add ``single_total_size`` to obtain global indices.
        """
        self._restore_cache()
        sampler = PrioritizedReplayBuffer.sample_indices if use_prio else ReplayBuffer.sample_indices
        return sampler(self.her_buffer, batch_size)

    def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
        """Fetch transitions by global index, routing HER indices to ``self.her_buffer``.

        Indices ``< self.single_total_size`` are read from the regular buffers and get
        the ``_extra_sample_keys`` metadata attached. Larger indices are shifted by
        ``single_total_size`` and read from the HER buffer. ``buffer[:]`` returns all
        regular data. An int index yields a batch of length 1.

        TODO: for now, indices MUST be normal buffer first, her next, otherwise the
        return order will be different (regular rows are always returned first).
        """
        if isinstance(index, slice):  # change slice to np array
            # buffer[:] will get all available data
            indices = self.sample_indices(0) if index == slice(None) \
                else self._indices[:len(self)][index]
        else:
            # accept ints and lists (as the signature allows), not only numpy arrays
            indices = np.asarray(index)
        is_her = indices >= self.single_total_size
        her_indices = indices[is_her] - self.single_total_size
        main_indices = indices[~is_her]

        if len(main_indices) == 0 and len(her_indices) > 0:
            # only HER rows were requested
            her_batch = self.her_buffer.__getitem__(her_indices)
            for k in self._extra_sample_keys:
                her_batch.__dict__[k] = self.her_buffer._meta[k][her_indices]
            return her_batch

        # regular rows; an empty request also lands here, so it works without a HER buffer
        batch = super().__getitem__(main_indices)
        for k in self._extra_sample_keys:
            batch.__dict__[k] = self._meta[k][main_indices]
        if len(her_indices) > 0:
            # NOTE(review): unlike the HER-only path, the extra sample keys are not
            # attached to the HER part here; confirm how Batch.cat treats keys that are
            # present on only one side.
            her_batch = self.her_buffer.__getitem__(her_indices)
            return Batch.cat([batch, her_batch])
        return batch

    def add_her(
            self,
            batch):
        """Store relabeled transition(s) in the HER buffer.

        Updates the HER buffer's graph counts, adds ``batch`` to it, resets the
        dynamics-update counters of the written slots and refreshes the HER entry
        of ``self._lengths``.

        :return: ``(ptrs, ep_rews, ep_lens, ep_idxs)`` from the HER buffer's ``add``.
            ``ptrs`` are local to the HER buffer.
        """
        self.her_buffer.update_graph_count(batch.graph_count_idx)
        ptrs, ep_rews, ep_lens, ep_idxs = self.her_buffer.add(batch)
        # ``ptrs`` index into the HER child buffer, while num_dynamics_updates uses global
        # indices (add() resets it with the manager's global pointers). Shift by the HER
        # offset so the HER slots are reset, not the first regular buffer's.
        self.num_dynamics_updates[ptrs + self.single_total_size] = 0
        self._lengths[self.her_idx] = len(self.her_buffer)
        return ptrs, ep_rews, ep_lens, ep_idxs

    def add(
        self,
        batch: Batch,
        buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add regular (non-HER) transitions.

        Updates the graph counts and, when change counting is enabled, the per-variable
        change / no-change tallies. It then delegates to the parent ``add`` and
        resets the dynamics-update counters of the written slots.
        """
        self._restore_cache()
        self.update_graph_count(batch.graph_count_idx)

        if self.dynamics_per_use_change_count:
            changed = self.get_if_variable_changes(batch)
            self.variable_change_count += changed.sum(axis=0)
            self.variable_no_change_count += (~changed).sum(axis=0)

        written, rews, lens, ep_ids = super().add(batch, buffer_ids)
        self.num_dynamics_updates[written] = 0
        return written, rews, lens, ep_ids

    def logging_her_stats(self, writer: SummaryWriter, step: int) -> bool:
        """Summarise and then reset the per-factor HER relabeling / sampling statistics.

        For each factor whose ``her_selected_count_idx`` and ``sampled_count_idx``
        lists are both non-empty, this:
        - counts how often each graph (count index) was chosen by HER relabeling and
          by sampling;
        - adds zero entries for graphs with a positive ``valid_graph_count`` that were
          never chosen;
        - builds per-graph percentages and human-readable graph names.

        The figure that would be written to ``writer`` is currently commented out, so
        ``writer`` and ``step`` are unused and the computed lists are discarded.
        ``self.her_stats`` is always reset to empty lists at the end.

        :return: True if at least one factor had statistics to summarise.
        """
        # writes the her resampling statistics to the summaryWriter
        # TODO: we could also just write to a file location with a step counter
        logged = False

        for factor, stats in enumerate(self.her_stats):
            # the sample rate for each factor
            her_selected_count_idx = stats["her_selected_count_idx"]
            sampled_count_idx = stats["sampled_count_idx"]
            if not her_selected_count_idx or not sampled_count_idx:
                continue

            logged = True

            # convert samples to (unique, count)
            sampled_count_idxes, sampled_count = np.unique(sampled_count_idx, return_counts=True)
            her_selected_count_idxes, her_selected_counts = np.unique(her_selected_count_idx, return_counts=True)
            sample_stats = dict(zip(sampled_count_idxes, sampled_count))
            her_stats = dict(zip(her_selected_count_idxes, her_selected_counts))

            # add graph that has never been sampled
            for j, count in enumerate(self.valid_graph_count[factor]):
                if count > 0:
                    if j not in her_selected_count_idxes:
                        her_stats[j] = 0
                    if j not in sampled_count_idxes:
                        sample_stats[j] = 0

            # total number of values sampled with her
            total_her_count = np.sum(list(her_stats.values()))
            total_sampled_count = np.sum(list(sample_stats.values()))
            total_graph_count = np.sum(self.valid_graph_count[factor])

            graph_names = []
            her_selected_count_percents = []
            sampled_count_percents = []
            graph_count_percents = []

            # the last entry of each graph vector refers to the action ("act")
            factor_names = self.extractor.factor_names + ["act"]
            # graphs are visited in ascending order of their HER selection count
            for count_idx, her_selected_count in sorted(her_stats.items(), key=lambda item: item[1]):
                # for each different graph, indicate the number sampled for that graph
                her_selected_count_percents.append(100 * her_selected_count / total_her_count)
                sampled_count_percents.append(100 * sample_stats.get(count_idx, 0) / total_sampled_count)

                # name the graph as "parent, parent -> child factor"
                parents = self.count_idx_to_graph[count_idx].astype(bool)
                graph_name = ", ".join([factor_names[j] for j, p in enumerate(parents)
                                        if p])
                graph_name = graph_name + " -> " + factor_names[factor]
                graph_names.append(graph_name)

                graph_count = self.valid_graph_count[factor, count_idx]
                graph_count_percents.append(100 * graph_count / total_graph_count)

            # NOTE: the plotting below is disabled, so the lists built above are unused.
            # num_graphs = len(graph_names)
            # fig = plt.figure(figsize=(10, max(num_graphs * 1, 3)))

            # # plot her resampling graph frequency
            # ax = plt.gca()
            # y = np.arange(num_graphs)

            # height = 0.2
            # rects = ax.barh(y + height, her_selected_count_percents,  height=height, label="her relabeling frequency")
            # ax.bar_label(rects, label_type='edge', fmt="%.3f", padding=3)
            # rects = ax.barh(y, sampled_count_percents, height=height, label="PER sampling frequency")
            # ax.bar_label(rects, label_type='edge', fmt="%.3f", padding=3)

            # # plot total graph count info
            # rects = ax.barh(y - height, graph_count_percents, height=height, label="percentage in buffer")
            # ax.bar_label(rects, label_type='edge', fmt="%.3f", padding=3)

            # plt.xlim([0, 1.1 * max(np.max(her_selected_count_percents),
            #                        np.max(sampled_count_percents),
            #                        np.max(graph_count_percents))])

            # ax.set_yticks(y)
            # ax.set_yticklabels(graph_names)

            # plt.legend(loc="lower right")
            # fig.tight_layout()
            # writer.add_figure(f"her_stats_{factor}", fig, step)
            # plt.close("all")

        # reset the collected statistics for the next logging period
        self.her_stats = [{"her_selected_count_idx": [],
                           "sampled_count_idx": []}
                          for _ in range(self.num_factors)]

        return logged
    # def get_upper_indices(
    #     self,
    #     buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    # ) -> np.ndarray:
    #     ptrs = []
    #     for buffer_id in buffer_ids:
    #         buffer = self.buffers[buffer_id]
    #         if len(buffer) == buffer.maxsize:
    #             ptrs.append(buffer._index + self._offset[buffer_id])
    #     ptrs = np.array(ptrs, dtype=int)
    #     if "upper_buffer_index" in self._meta:
    #         upper_indices = self.upper_buffer_index[ptrs]
    #     else:
    #         # lower buffer data not initialized yet
    #         upper_indices = np.array([], dtype=int)
    #     return upper_indices

    def get_buffer_indices(self, start, end):
        """Return the flat indices from ``start`` to ``end`` (both inclusive).

        If ``end >= start`` the range is contiguous: ``start, ..., end``.
        Otherwise the range is treated as wrapping around the end of the
        sub-buffer that contains ``start``:
        ``start, ..., buffer_end - 1, buffer_start, ..., end``.
        Here ``[buffer_start, buffer_end)`` are paired from ``self._offset``
        and ``self._extend_offset[1:]``.

        Args:
            start: first flat index of the range.
            end: last flat index of the range (inclusive).

        Returns:
            np.ndarray of integer flat indices.

        Raises:
            ValueError: if the range wraps but ``start`` is not inside any
                sub-buffer, or ``end`` is not inside the same sub-buffer
                as ``start``.
        """
        if end >= start:
            return np.arange(start, end + 1)

        # Wrapped range: find the sub-buffer [buffer_start, buffer_end) holding `start`.
        # for/else makes "no sub-buffer found" an explicit error instead of silently
        # falling through with the last sub-buffer's bounds.
        for buffer_start, buffer_end in zip(self._offset, self._extend_offset[1:]):
            if buffer_start <= start < buffer_end:
                break
        else:
            raise ValueError(
                f"start={start} does not lie within any sub-buffer")

        # A wrapped range must not cross into a different sub-buffer.
        if not buffer_start <= end < buffer_end:
            raise ValueError(
                f"wrapped range start={start}, end={end} does not lie within "
                f"a single sub-buffer [{buffer_start}, {buffer_end})")

        return np.concatenate(
            [np.arange(start, buffer_end),
             np.arange(buffer_start, end + 1)]
        )
    
    def init_her_utils(self) -> None:
        """Set up the episode-tracking arrays used for HER relabeling.

        Does nothing unless ``self.use_her`` is truthy. Otherwise it sets
        ``self.her_use_episode_tracker = True`` and allocates:

        * ``episode_ptrs`` (total_size,) int: index of the episode each
          transition belongs to. The episode is identified by the index of its
          first transition ADDED to the buffer.
        * ``episode_start_index`` (total_size,) int: first transition of the
          episode (inclusive).
        * ``episode_end_index`` (total_size,) int: last transition of the
          episode (exclusive).

        All three are filled with -1. It also allocates:

        * ``if_prev_episode_ends`` (buffer_num,) bool, all True: whether the
          next index to be added starts a new episode.
        * ``current_episode_ptrs``: a copy of ``self._offset``.

        These arrays let episode start/end be looked up directly rather than
        by repeatedly calling ``self.prev()`` / ``self.next()``.
        """
        if not self.use_her:
            return

        self.her_use_episode_tracker = True
        if not self.her_use_episode_tracker:
            return

        n_slots = self.total_size
        (self.episode_ptrs,
         self.episode_start_index,
         self.episode_end_index) = (np.full(n_slots, -1, dtype=int)
                                    for _ in range(3))

        # Every sub-buffer starts "between episodes", so its first add opens one.
        self.if_prev_episode_ends = np.ones(self.buffer_num, dtype=bool)
        self.current_episode_ptrs = np.copy(self._offset)

    def reset_policy_PER(self) -> None:
        """Re-initialise per-factor policy PER weight trees and seed their priorities.

        Steps, in order:
          1. Prints ``total_size`` and ``maxsize`` (debug output).
          2. If ``self.policy_use_pser``, allocates ``pser_stats``: one float32
             zero array of length ``total_size`` per factor.
          3. If a ``her_buffer`` attribute exists, calls
             ``GCReplayBuffer.init_weight_tree`` with names
             ``policy_0 .. policy_{num_factors-1}``. It is called on ``self``
             with ``maxsize`` temporarily reduced by ``her_size`` (then
             restored), and on ``self.her_buffer``.
          4. For every index returned by ``self.sample_indices(0)``, computes
             per-factor weights via ``compute_graph_count_weight``. It rescales
             them so the largest equals ``self._max_prio`` and writes them with
             the parent class's ``update_weight``.
        """
        print(self.total_size, self.maxsize)
        if self.policy_use_pser:
            self.pser_stats = [np.zeros(self.total_size, dtype=np.float32)
                               for _ in range(self.num_factors)]
        
        if hasattr(self, "her_buffer"):
            # Shrink maxsize while building the main buffer's trees, presumably so they
            # cover only the non-HER part of the index space -- TODO confirm against init_weight_tree.
            self.maxsize = self.maxsize - self.her_size
            GCReplayBuffer.init_weight_tree(self, [f"policy_{i}" for i in range(self.num_factors)])
            self.maxsize = self.maxsize + self.her_size
            GCReplayBuffer.init_weight_tree(self.her_buffer, [f"policy_{i}" for i in range(self.num_factors)])

        index = self.sample_indices(0)      # all active indices
        print("base buffer indices", len(index))
        if len(index) != 0:
            for i in range(self.num_factors):
                # Alias self.weight to self.tree so update_weight (which presumably writes
                # to self.weight) operates on the tree; copied back below.
                # NOTE(review): the same self.tree is used for every factor i; verify that
                # update_weight selects the per-factor tree rather than overwriting one.
                self.weight = self.tree
                # NOTE(review): always True inside range(self.num_factors); the else branch is unreachable.
                if i < self.num_factors:
                    # `tiny` keeps weights strictly positive (assuming the graph-count weight is
                    # non-negative), so weight.max() below is never zero.
                    weight = self.compute_graph_count_weight(index, i) + np.finfo(np.float32).tiny
                    weight *= self._max_prio / weight.max()         # scale weight so that least visited graph has max prio
                else:
                    weight = np.ones_like(index) * self._max_prio
                super().update_weight(index, weight)
                self.tree = self.weight

    # def update_upper_buffer_idx(self, start, end, upper_buffer_idx):
    #     idx = self.get_lower_buffer_indices(start, end)
    #     self.upper_buffer_index[idx] = upper_buffer_idx

    # The methods below are copied from tianshou's HERReplayBufferManager.
    def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
        """Forward ``save_hdf5`` to the parent class unchanged.

        ``path`` and ``compression`` are passed through positionally, and the
        parent's return value is returned as-is.
        """
        parent_save = super().save_hdf5
        return parent_save(path, compression)

    def set_batch(self, batch: Batch) -> None:
        """Forward ``set_batch`` to the parent class and return its result."""
        delegate = super().set_batch
        return delegate(batch)

    def update(self, buffer) -> np.ndarray:
        """Forward ``update`` to the parent class and return its result unchanged."""
        result = super().update(buffer)
        return result

    def restore_cache(self):
        """Intentionally a no-op.

        Presumably kept so code written against the HER buffer manager can call
        ``restore_cache`` without effect. TODO confirm no caller relies on a
        real cache restore.
        """

    # def prev(self, index: Union[int, np.ndarray]) -> np.ndarray:
    #     if isinstance(index, (list, np.ndarray)):
    #         prev_indices = np.zeros_like(index)
    #         prev_indices[index >= self.single_total_size] = self.her_buffer.prev(index[index >= self.single_total_size] - self.single_total_size) + self.single_total_size
    #         prev_indices[index < self.single_total_size] = super().prev(index[index < self.single_total_size])
    #         return prev_indices
    #     else:
    #         if index >= self.single_total_size: return self.her_buffer.prev(index - self.single_total_size) + self.single_total_size
    #         else: super().prev(index)

    # def next(self, index: Union[int, np.ndarray]) -> np.ndarray:
    #     if isinstance(index, (list, np.ndarray)):
    #         next_indices = np.zeros_like(index)
    #         next_indices[index >= self.single_total_size] = self.her_buffer.next(index[index >= self.single_total_size] - self.single_total_size) + self.single_total_size
    #         next_indices[index < self.single_total_size] = super().next(index[index < self.single_total_size])
    #         return next_indices
    #     else:
    #         if index >= self.single_total_size: return self.her_buffer.next(index - self.single_total_size) + self.single_total_size
    #         else: super().next(index)


    # def __len__(self) -> int:
    #     print(self.her_buffer._size)
    #     return self._size + self.her_buffer._size

    # def __repr__(self) -> str:
    #     return super().__repr__() + self.her_buffer.__repr__()

    # def __getattr__(self, key: str) -> Any:
    #     start = time.time()
    #     if type(super().__getattr__(key)) == np.ndarray:
    #         batch= np.concatenate([super().__getattr__(key), self.her_buffer.__getattr__(key)], axis=0)
    #     else:
    #         batch= Batch.cat([super().__getattr__(key), self.her_buffer.__getattr__(key)])
    #     print(key, time.time() - start)
    #     return batch

    # def get_main(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
    #     """gets a batch from the main, not HER buffer
    #     """
    #     if isinstance(index, slice):  # change slice to np array
    #         # buffer[:] will get all available data
    #         indices = self.sample_indices(0) if index == slice(None) \
    #             else self._indices[:len(self)][index]
    #     else:
    #         indices = index  # type: ignore
    #     # raise KeyError first instead of AttributeError,
    #     # to support np.array([ReplayBuffer()])
    #     obs = self.get(indices, "obs")
    #     if self._save_obs_next:
    #         obs_next = self.get(indices, "obs_next", Batch())
    #     else:
    #         obs_next = self.get(self.next(indices), "obs", Batch())
    #     batch_dict = {
    #         "obs": obs,
    #         "act": self._meta["act"][indices],
    #         "rew": self._meta["rew"][indices],
    #         "terminated": self._meta["terminated"][indices],
    #         "truncated": self._meta["truncated"][indices],
    #         "done": self._meta["done"][indices],
    #         "obs_next": obs_next,
    #         "info": self.get(indices, "info", Batch()),
    #         "policy": self.get(indices, "policy", Batch()),
    #     }
    #     for key in self._meta.__dict__.keys():
    #         if key not in self._input_keys:
    #             batch_dict[key] = self._meta[key][indices]
    #     return Batch(batch_dict)


    # def sample_indices(self, batch_size: int) -> np.ndarray:
    #     if batch_size < 0:
    #         return np.array([], int)
    #     if self._sample_avail and self.stack_num > 1:
    #         all_indices = np.concatenate(
    #             [
    #                 buf.sample_indices(0) + offset
    #                 for offset, buf in zip(self._offset, self.buffers)
    #             ]
    #         )
    #         if batch_size == 0:
    #             return all_indices
    #         else:
    #             return np.random.choice(all_indices, batch_size)
    #     if batch_size == 0:  # get all available indices
    #         sample_num = np.zeros(self.buffer_num, int)
    #     else:
    #         buffer_idx = np.random.choice(
    #             self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
    #         )
    #         sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
    #         # avoid batch_size > 0 and sample_num == 0 -> get child's all data
    #         sample_num[sample_num == 0] = -1

    #     return np.concatenate(
    #         [
    #             buf.sample_indices(bsz) + offset
    #             for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
    #         ]
    #     )
